
"""python_type_tokenizer.py — FINAL (round‑trip safe)

• Scalars tagged: <INT>-7, <FLOAT>3.14, <BOOL>True, <STR>'hello'
• Containers tagged with open/close: <LIST>[ … <LIST>] , <TUPLE>( … <TUPLE>) , <TUPLE>().
• Tokenization removes quotes *only* inside tokens (model sees <STR>hello).
• detag_text(tag_text(s)) == s for all included examples.
"""

from __future__ import annotations
import ast, io, re, tokenize as py_tok
from typing import List, Tuple

__all__ = ["PyTypeTokenizer"]

# Scalar tag map: Python literal type -> inline tag prepended to the literal.
_CONST_TAG = {int: "<INT>", float: "<FLOAT>", bool: "<BOOL>", str: "<STR>"}
ALL_TAGS   = list(_CONST_TAG.values()) + ["<LIST>", "<TUPLE>"]

# Regex helpers
_TAG_RE    = re.compile(r"<[^>]+>")                    # any tag (for detag_text)
_MINUS_FIX = re.compile(r"-(<INT>|<FLOAT>)(?=[0-9])")  # "-<INT>7" -> "<INT>-7"

_LIST_RE   = re.compile(r"\[[^\[\]]*?\]")              # innermost [...] span
_TUPLE_RE  = re.compile(r"\([^()]*?,[^()]*?\)")        # (...) containing a comma
_EMPTY_TUP = re.compile(r"\(\)")

_SPLIT_RE = re.compile(
    r"<TUPLE>\(\)"                               # empty tuple
    r"|<BOOL>True|<BOOL>False"                   # booleans
    r"|<[A-Z]+>[-+]?\d+\.\d+(?:e[-+]?\d+)?"      # floats
    r"|<[A-Z]+>[-+]?\d+"                         # ints
    r"|<[A-Z]+>'[^']*'|<[A-Z]+>\"[^\"]*\""       # string literals with quotes
    r"|<(?:LIST|TUPLE)>\[|<(?:LIST|TUPLE)>\(|<(?:LIST|TUPLE)>\]|<(?:LIST|TUPLE)>\)"  # container markers
    r"|<[^>]+>"                                  # fallback tag
    r"|[A-Za-z_][A-Za-z0-9_]*"                   # identifiers
    r"|[-+*/%^=(){}\[\].?:]"                     # punctuation
)

class PyTypeTokenizer:
    """Inline datatype tagging + tokenization.

    ``tag_text`` inserts type tags immediately before literals the Python
    tokenizer recognizes; ``detag_text`` strips every tag.  The invariant
    ``detag_text(tag_text(s)) == s`` holds because tags are pure insertions:
    no original character is ever dropped or reordered.
    """

    # --------------------------- tag_text ---------------------------
    def tag_text(self, text: str) -> str:
        """Return *text* with type tags inserted before recognizable literals.

        Tagging is best-effort: if the Python tokenizer chokes partway
        through (free text often is not valid Python), the spans collected
        up to that point are still applied.
        """
        # Absolute offset of the start of each (1-indexed) line, so
        # (row, col) token positions stay correct on multi-line input.
        # (Using columns alone would corrupt any text containing a
        # newline before a literal.)
        line_starts = [0]
        for line in text.splitlines(keepends=True):
            line_starts.append(line_starts[-1] + len(line))

        def abs_pos(row: int, col: int) -> int:
            return line_starts[row - 1] + col

        spans: List[Tuple[int, int, str]] = []
        buf = io.BytesIO(text.encode())
        prev = None
        try:
            for tok in py_tok.tokenize(buf.readline):
                ttype, tstr, (srow, scol), (erow, ecol), _ = tok
                start, end = abs_pos(srow, scol), abs_pos(erow, ecol)
                # Fold a '-' with the number that follows it ("-7" becomes
                # one literal).  The minus must be *immediately adjacent*:
                # in "5 - 3" the minus is binary, and folding it would
                # swallow the space and break the round-trip.
                if (prev is not None and prev.type == py_tok.OP
                        and prev.string == '-' and ttype == py_tok.NUMBER
                        and prev.end == (srow, scol)):
                    start = abs_pos(*prev.start)
                    tstr = '-' + tstr
                    prev = None
                else:
                    prev = tok
                tag = None
                if ttype == py_tok.NUMBER:
                    try:
                        tag = _CONST_TAG[type(ast.literal_eval(tstr))]
                    except Exception:
                        pass  # e.g. complex literals: leave untagged
                elif ttype == py_tok.STRING:
                    tag = "<STR>"
                elif ttype == py_tok.NAME and tstr in ("True", "False"):
                    tag = "<BOOL>"
                if tag:
                    spans.append((start, end, tag + tstr))
        except (py_tok.TokenError, SyntaxError):
            pass  # keep whatever spans we managed to collect

        # Splice tagged literals back in right-to-left so earlier
        # offsets remain valid.
        chars = list(text)
        for s, e, tagged_literal in reversed(spans):
            chars[s:e] = [tagged_literal]
        tagged = "".join(chars)

        # Normalize any leftover "-<INT>7" to "<INT>-7".
        tagged = _MINUS_FIX.sub(lambda m: f"{m.group(1)}-", tagged)

        # Container open/close markers.
        tagged = _LIST_RE.sub(lambda m: "<LIST>[" + m.group(0)[1:-1] + "<LIST>]", tagged)
        tagged = _TUPLE_RE.sub(lambda m: "<TUPLE>(" + m.group(0)[1:-1] + "<TUPLE>)", tagged)
        tagged = _EMPTY_TUP.sub("<TUPLE>()", tagged)

        return tagged

    # --------------------------- detag_text -------------------------
    def detag_text(self, s: str) -> str:
        """Strip every ``<...>`` tag, recovering the original text."""
        return _TAG_RE.sub("", s)

    # --------------------------- tokenize ---------------------------
    def tokenize(self, s: str, *, pretagged: bool = False) -> List[str]:
        """Split (optionally pre-tagged) text into model tokens.

        Commas are dropped; quotes are removed *inside* ``<STR>`` tokens
        only, so the model sees ``<STR>hello`` rather than ``<STR>'hello'``.
        """
        text = s if pretagged else self.tag_text(s)
        raw = [tok for tok in _SPLIT_RE.findall(text) if tok != ',']

        # strip quotes inside <STR> tokens
        cleaned: List[str] = []
        for tok in raw:
            if tok.startswith("<STR>"):
                lit = tok[5:]
                if lit and lit[0] in ("'", '"') and lit[-1] == lit[0]:
                    lit = lit[1:-1]
                cleaned.append("<STR>" + lit)
            else:
                cleaned.append(tok)
        return cleaned

    __call__ = tag_text

    # --------------------------- register --------------------------
    @staticmethod
    def register_tokenizer(hf_tok, extra=None):
        """Add all tag strings (plus *extra*) to a HuggingFace tokenizer."""
        hf_tok.add_tokens(ALL_TAGS + (extra or []), special_tokens=False)
        return hf_tok

# ------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: tag two sample sentences, show the tagged form and the
    # token stream, and verify the round-trip invariant on each.
    tokenizer = PyTypeTokenizer()
    samples = (
        "Replace 'cat' with 'dog' in \"concatenate\".",
        "Is () considered empty?",
    )
    for sample in samples:
        tagged = tokenizer.tag_text(sample)
        print("TAG:", tagged)
        print("TOK:", tokenizer.tokenize(tagged, pretagged=True))
        assert tokenizer.detag_text(tagged) == sample